# -*- coding: utf-8 -*-
"""
train_piston_v5_4_ultra_stable_bw_fast_dashboard.py
--------------------------------------------------------
Train a single-channel U-Net to reconstruct grayscale PNG images and publish
training progress for a web dashboard.

- Writes piston_status.json (GPU + training status) and loss_plot.png.
- Creates the piston_views preview directory (no previews are written by the
  code in this file).
- Loads checkpoints with ``weights_only=True`` (PyTorch 2.6 safe loading).
"""
import os
import sys
import io
import time
import json
import shutil
import cv2
import torch
import numpy as np
import GPUtil
import matplotlib

matplotlib.use("Agg")  # select a non-GUI backend before pyplot is imported (headless safe)

import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from pytorch_msssim import ssim
from pathlib import Path
from tqdm import tqdm
from datetime import datetime as dt
from torch.amp import autocast, GradScaler  # automatic mixed precision

# ---------- Path setup ----------
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TMP = os.path.join(ROOT, "tmp")
os.makedirs(TMP, exist_ok=True)
LOGF = os.path.join(TMP, "train_stdout.log")

# Fixed deployment locations used by the training run and the dashboard.
PROJECT_ROOT = Path(r"C:\xampp\htdocs\cs_ai")
OUTPUT_DIR = PROJECT_ROOT / "outputs"
STATUS_JSON = OUTPUT_DIR / "piston_status.json"
LOSS_PLOT = OUTPUT_DIR / "loss_plot.png"

MAX_EPOCHS = 500          # last epoch number that is trained (inclusive)
LOG_EVERY_STEPS = 50      # how often the per-step progress line is emitted
EARLY_STOP_PATIENCE = 10  # epochs without improvement before stopping

# ---------- UTF-8 and Tee ----------
try:
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")
    sys.stderr.reconfigure(encoding="utf-8", errors="replace")
except Exception:
    # Best effort: streams that cannot be reconfigured are left as they are.
    pass


class Tee(io.TextIOBase):
    """Text stream that forwards every write to several underlying streams."""

    def __init__(self, *streams):
        self.streams = streams

    def write(self, s):
        """Write ``s`` to each stream and flush it; per-stream errors are ignored."""
        for st in self.streams:
            try:
                st.write(s)
                st.flush()
            except Exception:
                # Deliberate best effort: one broken sink must not stop logging.
                pass
        return len(s)

    def flush(self):
        """Flush each stream; per-stream errors are ignored."""
        for st in self.streams:
            try:
                st.flush()
            except Exception:
                pass


# Everything printed to stdout/stderr is also appended to the log file.
_log_fp = open(LOGF, "a", encoding="utf-8")
sys.stdout = Tee(sys.stdout, _log_fp)
sys.stderr = Tee(sys.stderr, _log_fp)


def log(msg: str):
    """Print ``msg`` prefixed with the current HH:MM:SS time."""
    print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)


# ==========================================================
# Model architecture
# ==========================================================
def conv_block(in_ch, out_ch):
    """Return two 3x3 Conv -> BatchNorm -> SiLU stages mapping in_ch to out_ch."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.SiLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.SiLU(inplace=True),
    )


class UNetUltra(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Args:
        ch: number of input channels.
        base: channel width of the first encoder stage; deeper stages use
            2x, 4x, 8x and 16x this width.

    The output has one channel and is passed through a sigmoid. Because the
    input is halved four times and skip connections are concatenated, the
    spatial size should be divisible by 16 for the shapes to line up.
    """

    def __init__(self, ch=1, base=32):
        super().__init__()
        self.enc1 = conv_block(ch, base)
        self.enc2 = conv_block(base, base * 2)
        self.enc3 = conv_block(base * 2, base * 4)
        self.enc4 = conv_block(base * 4, base * 8)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = conv_block(base * 8, base * 16)
        self.up4 = nn.ConvTranspose2d(base * 16, base * 8, 2, 2)
        self.dec4 = conv_block(base * 16, base * 8)
        self.up3 = nn.ConvTranspose2d(base * 8, base * 4, 2, 2)
        self.dec3 = conv_block(base * 8, base * 4)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 2, 2)
        self.dec2 = conv_block(base * 4, base * 2)
        self.up1 = nn.ConvTranspose2d(base * 2, base, 2, 2)
        self.dec1 = conv_block(base * 2, base)
        self.outc = nn.Conv2d(base, 1, 1)

    def forward(self, x):
        """Run the encoder/decoder and return the sigmoid-activated output."""
        e1 = self.enc1(x)
        p1 = self.pool(e1)
        e2 = self.enc2(p1)
        p2 = self.pool(e2)
        e3 = self.enc3(p2)
        p3 = self.pool(e3)
        e4 = self.enc4(p3)
        p4 = self.pool(e4)
        b = self.bottleneck(p4)
        u4 = self.up4(b)
        d4 = self.dec4(torch.cat([u4, e4], 1))
        u3 = self.up3(d4)
        d3 = self.dec3(torch.cat([u3, e3], 1))
        u2 = self.up2(d3)
        d2 = self.dec2(torch.cat([u2, e2], 1))
        u1 = self.up1(d2)
        d1 = self.dec1(torch.cat([u1, e1], 1))
        return torch.sigmoid(self.outc(d1))


# ==========================================================
# Dataset
# ==========================================================
class ImgDS(Dataset):
    """Grayscale ``*.png`` images from one folder, resized and scaled to [0, 1].

    Args:
        folder: directory that is searched (non-recursively) for ``*.png``.
        size: target ``(height, width)`` of every sample.

    Each item is ``(img, img)``: the same ``(1, H, W)`` float tensor is used
    as both input and target.
    """

    def __init__(self, folder, size=(1088, 1920)):
        self.files = sorted([str(p) for p in Path(folder).glob("*.png")])
        self.size = size

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        path = self.files[i]
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals an unreadable/corrupt file by returning None;
            # fail with the file name instead of an opaque resize error.
            raise OSError(f"cv2.imread could not read image: {path}")
        # cv2.resize expects (width, height), hence the reversed size tuple.
        img = cv2.resize(img, self.size[::-1])
        img = torch.from_numpy(img / 255.0).float().unsqueeze(0)
        return img, img


# ==========================================================
# Loss functions
# ==========================================================
# 3x3 edge kernels, kept on the CPU and moved to the prediction's device per call.
_SOBEL_X = torch.tensor(
    [[1, 0, -1], [2, 0, -2], [1, 0, -1]],
    dtype=torch.float32,
).view(1, 1, 3, 3)
_SOBEL_Y = torch.tensor(
    [[1, 2, 1], [0, 0, 0], [-1, -2, -1]],
    dtype=torch.float32,
).view(1, 1, 3, 3)
_LAPLACE = torch.tensor(
    [[0, -1, 0], [-1, 4, -1], [0, -1, 0]],
    dtype=torch.float32,
).view(1, 1, 3, 3)


def safe_ssim(y, x):
    """Return SSIM(y, x) with data_range=1.0, or a zero tensor on NaN or error."""
    try:
        v = ssim(y, x, data_range=1.0, size_average=True)
        return torch.tensor(0.0, device=y.device) if torch.isnan(v) else v
    except Exception:
        # Deliberate best effort: an SSIM failure must not abort training.
        return torch.tensor(0.0, device=y.device)


def sobel_loss(pred, target):
    """MSE between the |Gx| + |Gy| Sobel responses of ``pred`` and ``target``.

    Both responses are divided by 4 before the MSE. The kernels have a single
    input channel, so the inputs are expected to be single-channel images.
    """
    sobel_x = _SOBEL_X.to(pred.device)
    sobel_y = _SOBEL_Y.to(pred.device)
    e_pred = torch.abs(F.conv2d(pred, sobel_x, padding=1)) + torch.abs(
        F.conv2d(pred, sobel_y, padding=1)
    )
    e_tgt = torch.abs(F.conv2d(target, sobel_x, padding=1)) + torch.abs(
        F.conv2d(target, sobel_y, padding=1)
    )
    return F.mse_loss(e_pred / 4.0, e_tgt / 4.0)


def laplace_loss(pred, target):
    """MSE between absolute Laplacian responses (emphasises thin lines / text edges).

    Both responses are divided by 4 before the MSE.
    """
    k_lap = _LAPLACE.to(pred.device)
    e_pred = torch.abs(F.conv2d(pred, k_lap, padding=1))
    e_tgt = torch.abs(F.conv2d(target, k_lap, padding=1))
    return F.mse_loss(e_pred / 4.0, e_tgt / 4.0)


def contrast_loss(pred, target):
    """Squared difference of the global mean plus that of the global std.

    Mean and std are taken over all elements of each tensor, which pushes the
    overall black/white contrast of the prediction towards the target's.
    """
    pm, ps = pred.mean(), pred.std()
    tm, ts = target.mean(), target.std()
    return (pm - tm).pow(2) + (ps - ts).pow(2)


def total_loss(y, x):
    """Weighted sum of the reconstruction losses between prediction ``y`` and target ``x``.

    Terms and weights:
        - MSE at full resolution (weight 1.0)
        - Sobel edge loss (0.25)
        - Laplacian loss (0.20)
        - 1 - SSIM, computed on 512x512 bilinear resizes to reduce cost (0.10)
        - global contrast loss (0.08)
    """
    mse = F.mse_loss(y, x)
    sbl = sobel_loss(y, x)
    lap = laplace_loss(y, x)

    # SSIM is evaluated on downsized copies to cut computation.
    y_small = F.interpolate(y, size=(512, 512), mode="bilinear", align_corners=False)
    x_small = F.interpolate(x, size=(512, 512), mode="bilinear", align_corners=False)
    ssim_l = 1 - safe_ssim(y_small, x_small)

    cont = contrast_loss(y, x)

    # The weights favour sharp lines; they can be tuned later.
    return (
        mse
        + 0.25 * sbl     # edges
        + 0.20 * lap     # thin lines / text
        + 0.10 * ssim_l  # large-scale structure
        + 0.08 * cont    # contrast
    )


# ==========================================================
# Dashboard status file
# ==========================================================
def update_status_json(epoch, loss, best_loss):
    """Write the current GPU and training status to ``STATUS_JSON``.

    Args:
        epoch: epoch number to report.
        loss: loss value to report for that epoch.
        best_loss: best loss so far (stored under the key ``best_val``).

    GPU statistics come from the first GPU reported by GPUtil; if none is
    found or the query fails, zeroed placeholder values are written instead.
    """
    try:
        gpus = GPUtil.getGPUs()
        gpu = gpus[0] if gpus else None
        gpu_info = {
            "gpu_util": int(gpu.load * 100) if gpu else 0,
            "mem_used": int(gpu.memoryUsed) if gpu else 0,
            "mem_total": int(gpu.memoryTotal) if gpu else 0,
            "temperature": int(gpu.temperature) if gpu else 0,
            "name": gpu.name if gpu else "CPU",
        }
    except Exception:
        # Deliberate best effort: a failed GPU query must not stop training.
        gpu_info = {
            "gpu_util": 0,
            "mem_used": 0,
            "mem_total": 0,
            "temperature": 0,
            "name": "Unknown",
        }

    data = {
        "gpu": gpu_info,
        "train": {"epoch": epoch, "loss": loss, "best_val": best_loss},
        "time": dt.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    STATUS_JSON.parent.mkdir(parents=True, exist_ok=True)
    with open(STATUS_JSON, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


# ==========================================================
# Main training flow
# ==========================================================
def _save_loss_plot(loss_hist):
    """Render the per-epoch loss history to ``LOSS_PLOT``."""
    plt.figure(figsize=(6, 3))
    plt.plot(loss_hist)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(LOSS_PLOT)
    plt.close()


def _train_one_epoch(model, dl, opt, scaler, device, ep):
    """Run one training epoch and return the list of per-step loss values.

    Steps whose loss is NaN are skipped (no backward pass, not recorded).
    When ``scaler`` is not None the forward pass runs under fp16 autocast.
    """
    model.train()
    losses = []
    total_steps = len(dl)
    progress = tqdm(dl, desc=f"Epoch {ep}/{MAX_EPOCHS}", ncols=100)

    for i, (x, y) in enumerate(progress, start=1):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)

        if scaler is not None:
            with autocast(device_type="cuda", dtype=torch.float16):
                y_hat = model(x)
                loss = total_loss(y_hat, y)
            if torch.isnan(loss):
                continue
            scaler.scale(loss).backward()
            # Gradients are still multiplied by the loss scale here; unscale
            # them first so the clipping threshold applies to true gradients.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
        else:
            y_hat = model(x)
            loss = total_loss(y_hat, y)
            if torch.isnan(loss):
                continue
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

        loss_val = loss.item()
        losses.append(loss_val)

        if i % LOG_EVERY_STEPS == 0:
            line = (
                f"[{time.strftime('%H:%M:%S')}] "
                f"Epoch {ep}/{MAX_EPOCHS} step {i}/{total_steps} loss={loss_val:.4f}"
            )
            # 1) Truncate the log file so it holds only this latest progress
            #    line (the file is meant to be shown on the web page).
            with open(LOGF, "w", encoding="utf-8") as f:
                f.write(line + "\n")
            # 2) Also print to the console; the Tee appends a copy to the log.
            print(line, flush=True)

    return losses


def main():
    """Train ``UNetUltra`` on the project's images, resuming from a checkpoint if present.

    Per epoch: trains, steps the LR scheduler on the mean loss, saves a
    "best" weights file on improvement, updates the dashboard JSON and loss
    plot, and writes the resume checkpoint. Stops early after
    ``EARLY_STOP_PATIENCE`` epochs without improvement.

    Raises:
        FileNotFoundError: if the training folder contains no ``*.png`` files.
    """
    train_dir = PROJECT_ROOT / "data/train/images"
    model_dir = PROJECT_ROOT / "ai_models"
    view_dir = OUTPUT_DIR / "piston_views"
    for d in (model_dir, OUTPUT_DIR, view_dir):
        d.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    print(f"🚀 CUDA: {torch.cuda.is_available()} | Device: {device}")

    # Slimmed-down model: base=16
    model = UNetUltra(base=16).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=5e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode="min", factor=0.5, patience=5, min_lr=1e-7
    )
    scaler = GradScaler(device="cuda") if use_cuda else None

    ckpt_path = model_dir / "piston_v5_4_ultra_latest.ckpt"
    start_epoch, best_loss, patience_counter = 1, 9e9, 0
    loss_hist = []

    if ckpt_path.exists():
        try:
            # The checkpoint holds only tensors and plain Python values, so
            # the restricted (safe) unpickler is sufficient.
            ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
            model.load_state_dict(ckpt.get("model", {}), strict=False)
            opt.load_state_dict(ckpt.get("optimizer", {}))
            best_loss = ckpt.get("best_loss", 9e9)
            start_epoch = ckpt.get("epoch", 1)
            print(f"✅ 自動續訓成功,從第 {start_epoch} epoch 開始")
        except Exception as e:
            print(f"⚠️ 無法續訓,改從零開始: {e}")

    ds = ImgDS(train_dir)
    if len(ds) == 0:
        # Fail early with a clear message instead of an error from inside
        # the DataLoader's sampler.
        raise FileNotFoundError(f"no *.png training images found in {train_dir}")

    num_workers = 4
    dl = DataLoader(
        ds,
        batch_size=4,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=use_cuda,  # pinned host memory only helps CUDA transfers
        persistent_workers=True,
    )

    # ==========================
    # Training loop
    # ==========================
    # The stored checkpoint epoch is the NEXT epoch to run; MAX_EPOCHS itself
    # is trained, matching the "Epoch N/500" progress label.
    for ep in range(start_epoch, MAX_EPOCHS + 1):
        losses = _train_one_epoch(model, dl, opt, scaler, device, ep)
        if not losses:
            # Every step produced NaN: nothing to average or checkpoint.
            continue

        avg = float(np.mean(losses))
        loss_hist.append(avg)
        scheduler.step(avg)
        print(f"✅ Epoch {ep} | loss={avg:.6f}")

        # Track the best loss first so the dashboard reports the value that
        # already includes this epoch.
        if avg + 1e-6 < best_loss:
            best_loss = avg
            patience_counter = 0
            torch.save(
                model.state_dict(), model_dir / f"piston_ultra_best_ep{ep}.pth"
            )
        else:
            patience_counter += 1

        update_status_json(ep, avg, best_loss)
        _save_loss_plot(loss_hist)

        # Resume checkpoint, overwritten every epoch.
        torch.save(
            {
                "epoch": ep + 1,
                "model": model.state_dict(),
                "optimizer": opt.state_dict(),
                "best_loss": best_loss,
            },
            ckpt_path,
        )

        if patience_counter >= EARLY_STOP_PATIENCE:
            print("🛑 Early stop triggered.")
            break


if __name__ == "__main__":
    main()